""" Build the state dataset including rig utilization

"""
import pandas as pd
import numpy as np
import copy
import sys
sys.path.append('./')

from src.data_py import utils

#%% CONFIGURE ---------------------------------------------------------------------------
# Rig spec bands. Each entry is a [lower, upper] pair that is compared against the
# 'max wd' column further down (lower bound inclusive, upper bound exclusive).
spec_bounds = dict(
    low=[0, 200.1],
    mid=[200.1, 299.1],
    high=[299.1, 500],
)

# Daily calendar used when expanding by rig-date; kept to a moderate span because a
# longer range makes the expansion slow.
dates = pd.date_range(start='1995-01-01', end='2015-01-01')

# Status labels, split into three groups. "Modification" is currently assigned to
# the removed group; it has previously been tried in the other two groups as well.

# Statuses that count as the rig being utilized.
utilization_status = [
    "Drilling", "Workover", "Production", "Inspection",
    "Waiting on Loc", "Enroute", "Accommodation",
]

# Statuses that count as the rig not being utilized.
nonutilization_status = ["Cold Stacked", "Ready Stacked"]

# Statuses that are taken out altogether ("Cold Stacked" has been tried here too).
remove_status = ["Under Construction", "Modification"]

#%% READ IN DATA ------------------------------------------------------------------------
def _read_price_sheet(path, names):
    """Read sheet index 1 of a price workbook and parse its 'date' column.

    The header sits on row index 2 and the columns are relabelled with ``names``,
    which must include 'date'.
    """
    prices = pd.read_excel(path, sheet_name=1, header=2, names=names)
    prices['date'] = pd.to_datetime(prices['date'], format='%Y-%m-%d')
    return prices


# Rigzone records; the first two spreadsheet rows are skipped.
df_init = pd.read_excel("./data_py/raw/contracts/dayrates_rigzone_new.xls", skiprows=[0, 1])

# Working copy with lower-cased column labels; df_init keeps the raw import.
df = df_init.copy()
df.columns = df.columns.str.lower()

# Oil spot prices (two series) and the gas price series.
df_oil = _read_price_sheet(
    './data_py/raw/gas_prices/PET_PRI_SPT_S1_W.xls',
    names=['date', 'cushing', 'brent'],
)
df_gas = _read_price_sheet(
    './data_py/raw/gas_prices/RNGWHHDw.xls',
    names=['date', 'gas'],
)

# Deflator table; its 'gdpdef' column is what the price series are divided by later.
df_deflator = (
    pd.read_csv('data_py/processed/deflator_month.csv', index_col=[0])
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
)

#%% CLEAN UP OIL AND GAS PRICES ---------------------------------------------------------
def _monthly_deflated_prices(prices, deflator, first_year, last_year, price_cols):
    """Average prices by calendar month, keep a year window and deflate them.

    ``prices`` needs a datetime 'date' column. The result carries a 'month' column
    (first day of the month) plus every column of ``deflator``, including the
    deflator's own 'date', joined with a left merge. Each column in ``price_cols``
    is replaced by ``100 * price / gdpdef``.
    """
    monthly = (
        prices
        .groupby(pd.Grouper(key='date', freq='M'))
        .mean()
        .reset_index()
    )
    inside_window = monthly['date'].dt.year.between(first_year, last_year)
    monthly = (
        monthly[inside_window]
        # Re-label the month-end group date as the first day of that month so it
        # can be matched against the deflator's dates.
        .assign(month=lambda frame: pd.to_datetime(frame['date'].dt.strftime('%Y-%m')))
        .drop(columns=['date'])
        .merge(deflator, how='left', left_on='month', right_on='date')
    )
    for col in price_cols:
        monthly[col] = 100 * monthly[col] / monthly['gdpdef']
    return monthly


# Monthly series: gas is kept for 2000-2015, oil for 2000-2009.
df_gas_by_time = {
    'month': _monthly_deflated_prices(df_gas, df_deflator, 2000, 2015, ['gas']),
}
df_oil_month = _monthly_deflated_prices(df_oil, df_deflator, 2000, 2009, ['brent', 'cushing'])

# Save price datasets
df_oil_month.to_csv("./data_py/processed/oil_prices_deflated.csv")
df_gas_by_time['month'].to_csv("./data_py/processed/gas_prices_deflated.csv")

# Finally, interpolate the monthly gas price linearly between observations. The
# series is resampled to DAILY frequency even though it is stored under the
# 'fortnight' key.
df_gas_by_time['fortnight'] = (
    df_gas_by_time['month']
    .set_index('date')['gas']
    .resample('D')
    .interpolate(method='linear', limit_direction='forward')
    .reset_index()
)
df_gas_by_time['fortnight'].to_csv("./data_py/processed/gas_prices_deflated_fortnight.csv")

#%% CUT ---------------------------------------------------------------------------------
# Keep only jackup rigs in the US Gulf of Mexico. The explicit .copy() makes df an
# independent frame, so the column assignments below are not made on a slice of the
# unfiltered data (which is what triggers pandas' SettingWithCopyWarning).
df = df[
    (df['rig type'] == "Jackup")
    & (df['region'] == 'N. America - US GOM')
].copy()

# Aggregate spec: label every row with the band its 'max wd' falls into
# (lower bound inclusive, upper bound exclusive). The column is created with object
# dtype so that writing the string labels is not an incompatible-dtype assignment
# into a float column; rows outside every band keep NaN.
df['spec'] = pd.Series(np.nan, index=df.index, dtype=object)
for spec, (lower, upper) in spec_bounds.items():
    df.loc[(df['max wd'] >= lower) & (df['max wd'] < upper), 'spec'] = spec

# Reduce the status start/end columns to calendar dates. Note that .dt.date yields
# plain datetime.date objects (object dtype), not pandas Timestamps.
df['stat start'] = pd.to_datetime(df['stat start']).dt.date
df['stat end'] = pd.to_datetime(df['stat end']).dt.date

# Order rows by rig and then by status period.
df = df.sort_values(['rig name', 'stat start', 'stat end'])


#%% GET A DATASET OF ONLY CONTRACTS FROM THE RIGZONE DATA -------------------------------
# One row per (rig, contract start); rows without a contract start are discarded.
contract_rows = (
    df
    .drop_duplicates(subset=['rig name', 'k start'])
    .dropna(subset=['k start'])
)

# A row is flagged as a new contract when its operator differs from the previous
# row's operator while the rig is the same as on the previous row. Duration is the
# contract length in days.
previous_operator = contract_rows['operator'].shift(1)
previous_rig = contract_rows['rig name'].shift(1)
df_contracts = contract_rows.assign(
    new_contract=(
        (contract_rows['operator'] != previous_operator)
        & (contract_rows['rig name'] == previous_rig)
    ),
    duration=(contract_rows['k end'] - contract_rows['k start']).dt.days,
)

# Assign durations: tau is 1 below 75 days, 2 from 75 up to (not including) 105
# days, 3 from 105 days on, and NaN when the duration is missing.
contract_days = df_contracts['duration']
df_contracts['tau'] = np.select(
    [contract_days < 75, contract_days < 105, contract_days >= 105],
    [1, 2, 3],
    default=np.nan,
)

# Contract start rounded down to the month, and to the 1st / 15th of the month.
contract_start = df_contracts['k start']
start_year_month = contract_start.dt.strftime('%Y-%m')
half_month_suffix = np.where(contract_start.dt.day < 15, '-01', '-15')
df_contracts['contract_start_month'] = pd.to_datetime(start_year_month)
df_contracts['contract_start_fortnight'] = pd.to_datetime(start_year_month + half_month_suffix)

df_contracts.to_csv('./data_py/processed/rigzone_contracts_for_bootstrap.csv')


#%% GENERATE SOME AGGREGATE STATISTICS AT THE RIGTYPE-DAILY LEVEL -----------------------
# Give every status row a running integer id and use it as the index, then keep one
# row per (rig, status start, status end) and drop rows where either date is missing.
df = (
    df
    .rename(columns=str.lower)
    .assign(status_id=range(len(df)))
    .set_index('status_id')
    .drop_duplicates(['rig name', 'stat start', 'stat end'])
    .dropna(subset=['stat start', 'stat end'])
)
df.to_csv('./data_py/processed/rigzone_status_for_bootstrap.csv')

# Build the state tables; both results are indexed below by aggregation key.
df_state_by_time, df_agg_by_time = utils.construct_states(
    df,
    df_contracts,
    df_gas_by_time,
    utilization_status,
    nonutilization_status,
)

#%%
# Last date kept in the shortened (estimation) versions of the tables.
estimation_end = pd.to_datetime('2009-12-31')

for agg_key in df_state_by_time:
    # Write the full-length states first.
    states_long = df_state_by_time[agg_key]
    states_long.to_csv(f'./data_py/temp/07_build_states/states_{agg_key}_long.csv')

    # Shorten states for estimation...
    states_short = states_long[states_long['date'] <= estimation_end]
    df_state_by_time[agg_key] = states_short

    # ...the aggregates are cut on their second index level...
    aggregates = df_agg_by_time[agg_key]
    df_agg_by_time[agg_key] = aggregates[aggregates.index.get_level_values(1) <= estimation_end]

    states_short.to_csv(f'./data_py/temp/07_build_states/states_{agg_key}.csv')

    # ...and the gas prices are cut by calendar year.
    gas_prices = df_gas_by_time[agg_key]
    df_gas_by_time[agg_key] = gas_prices[gas_prices['date'].dt.year <= 2009]

#%% GET DATASET TO DO MULTINOMIAL LOGIT REGRESSION --------------------------------------
# For each aggregation level, build a long table in which every row of the counts
# series below is repeated as many times as its count, then write it out twice.
df_mnl_by_agg = dict()
for i in ['fortnight', 'month']:
    # Tag the aggregate rows with tau = 0 and move tau into the index.
    df_mnl_by_agg[i] = df_agg_by_time[i].copy()
    df_mnl_by_agg[i]['tau'] = 0
    df_mnl_by_agg[i] = df_mnl_by_agg[i].set_index('tau', append=True)

    # Number of rows flagged new_contract per (spec, contract start period, tau).
    n_matches_month_by_tau = (
        df_contracts
        .groupby(['spec', f'contract_start_{i}', 'tau'])['new_contract']
        .sum()
    )

    # Stack the contract counts on top of the rounded n_unemployed counts.
    # pd.concat replaces Series.append, which was removed in pandas 2.0; the result
    # is the same. The stacked value column is addressed as 0 after reset_index().
    df_agg_by_tau = (
        pd.concat([
            n_matches_month_by_tau,
            df_mnl_by_agg[i]['n_unemployed'].round(0).astype(int),
        ])
        .reset_index()
    )

    # Repeat each row according to its count, then drop the count and the
    # leftover positional index.
    df_mnl = (
        df_agg_by_tau
        .reindex(df_agg_by_tau.index.repeat(df_agg_by_tau[0]))
        .rename({'spec': 'rig_spec', f'contract_start_{i}': 'date'}, axis=1)
        .reset_index()
        .drop([0, 'index'], axis=1)
    )
    df_mnl.to_csv(f'./data_py/temp/07_build_states/mnl_{i}.csv')
    df_mnl.to_csv(f'./data_py/processed/mnl_{i}.csv')
